(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)
老頭已經介紹了 JAX 提供的 5 個控制流程,事實上,除了這 5 個之外,JAX 還有 map 和 associative_scan 另外兩個控制流程。因為它們語法和語義相當單純,老頭就不在這裏多談了,有興趣的讀者可以直接去看 JAX 的官方文件 [25.1]。要注意的是 map,JAX 官網建議我們儘量使用 vmap,老頭很快的就會為大家介紹 vmap。
最後用一張表做為控制流程的總結 [25.2]:
先看 jit 那一欄。Python 的 if 是不建議使用在 JIT 函式內,而 for 及 while 可以被有限度的使用,但是要注意:
再來看 grad 這一欄。grad 是 JAX 的 Auto Diff 功能的 API 名稱,這一欄說明了這些指令及運算子是否能夠支援 Auto Diff。雖然老頭還沒有介紹 Auto Diff,但是以下幾點,仍希望大家先放在心裏,未來在 Auto Diff 的說明中,會有更清楚的解釋。
那麼 fori_loop 在那些條件下才能支援逆向模式呢?其實 JAX 在處理 fori_loop 時,會將其轉成 while_loop 或是 scan,當它被轉成 while_loop 時,就只支援順向模式了。若是它被轉成 scan ,那麼就可以同時支援逆向及順向模式了。
fori_loop 被轉成 scan 的條件是:當其被追踪時,若是迴圈被執行的次數能夠被決定,就可以被轉成 scan。
例如:
fori_loop(1, 10, my_body, my_argument)
這行指令在被追踪時,已經可以知道它的重覆次數是 9 次,所以可以轉成 scan。
而下面這行指令:
fori_loop(x, y, my_body, my_argument)
在追踪時,x 和 y 都會用「追踪物件」來表示,並不會直接參考它們的值,也就是說,迴圈的重覆次數不能在追踪時被決定,因此 JAX 會用 while_loop 來實踐這行指令。
對於控制流程的介紹就在這告一段落,接下來,我們就要進入 Auto Diff。
註:
[25.1] map 可參考 這裏,associative_scan 可參考這裏 。
[25.2] 參考 JAX 官網文件 (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#summary)